{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {},
  "cells": [
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "   <a target=\"_blank\" href=\"https://labelbox.com\" ><img src=\"https://labelbox.com/blog/content/images/2021/02/logo-v4.svg\" width=256/></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "<a href=\"https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/integrations/huggingface/huggingface.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
        "</td>\n",
        "\n",
        "<td>\n",
        "<a href=\"https://github.com/Labelbox/labelbox-python/tree/master/examples/integrations/huggingface/huggingface.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white\" alt=\"GitHub\"></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Imports"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# for labelbox\n",
        "!pip install -q \"labelbox[data]\"\n",
        "import labelbox as lb\n",
        "# for custom embeddings in Labelbox\n",
        "!pip install -q 'git+https://github.com/Labelbox/advlib.git'\n",
        "#ndjson\n",
        "!pip install -q ndjson\n",
        "import ndjson\n",
        "import time\n",
        "import json"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Labelbox Credentials"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "API_KEY = \"\"\n",
        "\n",
        "client = lb.Client(API_KEY)\n",
        "\n",
        "# expose the key as the LABELBOX_API_KEY environment variable for shell commands run from this notebook\n",
        "%env LABELBOX_API_KEY=$API_KEY"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Select data rows in Labelbox for custom embeddings"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# get images from a Labelbox dataset\n",
        "# Our systems start to process data after 1000 embeddings of each type;\n",
        "# for this demo, make sure your dataset has more than 1000 data rows\n",
        "DATASET_ID = \"\""
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# look up the existing dataset by its ID\n",
        "dataset = client.get_dataset(DATASET_ID)\n",
        "\n",
        "# filled in by the export stream handler in a later cell\n",
        "data_row_ids = []\n",
        "data_row_urls = []"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "client.enable_experimental = True\n",
        "\n",
        "export_task = dataset.export()\n",
        "export_task.wait_till_done()\n"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Provide results with JSON converter\n",
        "# Returns streamed JSON output strings from export task results/errors, one by one\n",
        "\n",
        "# Callback used for JSON Converter\n",
        "def json_stream_handler(output: lb.JsonConverterOutput):\n",
        "  \"\"\"Record the ID and row_data value of one exported data row.\n",
        "\n",
        "  Parses `output.json_str` and appends to the notebook-level lists\n",
        "  `data_row_ids` and `data_row_urls`.\n",
        "  \"\"\"\n",
        "  data_row = json.loads(output.json_str)\n",
        "  data_row_ids.append(data_row[\"data_row\"][\"id\"])\n",
        "  data_row_urls.append(data_row[\"data_row\"][\"row_data\"])\n",
        "\n",
        "\n",
        "# Start from empty lists so re-running this cell does not append duplicates\n",
        "data_row_ids.clear()\n",
        "data_row_urls.clear()\n",
        "\n",
        "# Print any export errors\n",
        "if export_task.has_errors():\n",
        "  export_task.get_stream(\n",
        "    converter=lb.JsonConverter(),\n",
        "    stream_type=lb.StreamType.ERRORS\n",
        "  ).start(stream_handler=lambda error: print(error))\n",
        "\n",
        "# Collect the data row IDs and URLs from the export results\n",
        "if export_task.has_result():\n",
        "  export_json = export_task.get_stream(\n",
        "    converter=lb.JsonConverter(),\n",
        "    stream_type=lb.StreamType.RESULT\n",
        "  ).start(stream_handler=json_stream_handler)\n"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Get a HuggingFace Model to generate custom embeddings"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# import HuggingFace\n",
        "!pip3 install -q transformers\n",
        "!pip3 install -q timm\n",
        "\n",
        "# load a neural network from HuggingFace \n",
        "import transformers\n",
        "transformers.logging.set_verbosity(50)\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from PIL import Image\n",
        "import requests\n",
        "from tqdm import tqdm\n",
        "from concurrent.futures import ThreadPoolExecutor\n",
        "\n",
        "# get ResNet-50\n",
        "image_processor = transformers.AutoImageProcessor.from_pretrained(\"microsoft/resnet-50\")\n",
        "model = transformers.ResNetModel.from_pretrained(\"microsoft/resnet-50\")"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Pick an existing custom embedding in Labelbox, or create a custom embedding"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# See all custom embeddings available to your org\n",
        "!advtool embeddings list\n",
        "\n",
        "# Create a new custom embedding if needed\n",
        "# !advtool embeddings create ResNet50_2048_dimensions 2048 # any number of dimensions between 8 and 2048 is supported\n",
        "# This returns the ID of the newly created embedding, e.g. 0ddc5d5c-0963-41ad-9c5d-5c0963a1ad98"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Generate and upload custom embeddings\n",
        "We generate custom embeddings and upload them to Labelbox in batches of 512 images.\n"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "custom_embeddings = {}\n",
        "print_debug = True\n",
        "batch_size = 512\n",
        "\n",
        "# ID of the custom embedding to upload to.\n",
        "# Note: Please replace this value with your actual embedding ID.\n",
        "EMBEDDING_ID = \"a03948c1-151a-4a1a-b948-c1151a6a1a1d\"\n",
        "\n",
        "\n",
        "def threading_callback(data_row_url):\n",
        "  \"\"\"Download one image and return it as a 224x224 RGB PIL image.\"\"\"\n",
        "  # the timeout keeps one stalled download from hanging the whole batch;\n",
        "  # raise_for_status reports HTTP errors instead of an opaque image decode error\n",
        "  with requests.get(data_row_url, stream=True, timeout=30) as response:\n",
        "    response.raise_for_status()\n",
        "    return Image.open(response.raw).convert('RGB').resize((224, 224))\n",
        "\n",
        "\n",
        "# iterate over images in batches of size `batch_size`\n",
        "for i in range(0, len(data_row_urls), batch_size):\n",
        "\n",
        "  try:\n",
        "      print('iteration: ',i)\n",
        "      start = time.time()\n",
        "\n",
        "      # chunk of images in the batch\n",
        "      data_row_urls_chunk = data_row_urls[i:i+batch_size]\n",
        "      data_row_ids_chunk = data_row_ids[i:i+batch_size]\n",
        "\n",
        "      # download images in parallel\n",
        "      with ThreadPoolExecutor() as executer:\n",
        "        imgs = list(executer.map(threading_callback, data_row_urls_chunk))\n",
        "      # process images\n",
        "      img_hf = image_processor(imgs, return_tensors=\"pt\")\n",
        "      # run inference (no gradients needed) to get the last hidden state\n",
        "      with torch.no_grad():\n",
        "        last_layer = model(**img_hf, output_hidden_states=True).last_hidden_state\n",
        "      # average pool down to a (1, 1) spatial size to reduce dimensionality\n",
        "      resnet_embeddings = F.adaptive_avg_pool2d(last_layer, (1, 1))\n",
        "      resnet_embeddings = torch.flatten(resnet_embeddings,start_dim=1,end_dim=3) # flatten custom embedding\n",
        "\n",
        "      # convert resnet embeddings, from pytorch to python lists\n",
        "      resnet_embeddings = resnet_embeddings.tolist()\n",
        "\n",
        "      # Pair each data row ID with its embedding vector\n",
        "      payload = [\n",
        "          {\"id\": data_row_id, \"vector\": resnet_embedding}\n",
        "          for data_row_id, resnet_embedding in zip(data_row_ids_chunk, resnet_embeddings)\n",
        "      ]\n",
        "      # store ndjson on disk. it would take too much memory to keep every batch in memory...\n",
        "      # the NDJson file will be the payload for custom embeddings\n",
        "      with open('payload.ndjson', 'w') as f:\n",
        "          ndjson.dump(payload, f)\n",
        "\n",
        "      # Upload the NDJson file to Labelbox; {EMBEDDING_ID} is expanded by IPython\n",
        "      !advtool embeddings import {EMBEDDING_ID} ./payload.ndjson\n",
        "\n",
        "      end = time.time()\n",
        "      print('time taken for iteration: ',end-start)\n",
        "\n",
        "  except Exception as e:\n",
        "      # report which batch failed and why, then move on to the next batch\n",
        "      print('error: ', i, repr(e))\n"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Check the upload went well"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# count how many data rows have a specific custom embedding\n",
        "# Note: Please replace 'YOUR_EMBEDDING_ID_HERE' with your actual embedding ID.\n",
        "# Example command:\n",
        "# !advtool embeddings count YOUR_EMBEDDING_ID_HERE\n",
        "\n",
        "!advtool embeddings count a03948c1-151a-4a1a-b948-c1151a6a1a1d"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    }
  ]
}